import os

import torch
import torchvision
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate
from torchvision import transforms
from torchvision.transforms import autoaugment
from torchvision.transforms.functional import InterpolationMode

# Import necessary custom classes from elsewhere in your codebase if needed
from datasets.cifar10 import SequentialCIFARWrapper  # Assuming you have this wrapper implemented for sequential data
from datasets.cifar100 import SequentialCIFARWrapper  # Assuming similar wrapper for CIFAR-100
# The two imports above bind the same name, so the CIFAR-100 wrapper shadows the
# CIFAR-10 one. These aliases keep both dataset-specific wrappers addressable.
from datasets.cifar10 import SequentialCIFARWrapper as SequentialCIFAR10Wrapper
from datasets.cifar100 import SequentialCIFARWrapper as SequentialCIFAR100Wrapper
from utils.data_augmentation import RandomMixup, RandomCutmix, ClassificationPresetTrain

class _MixupCutmixCollate:
    """Collate a batch with ``default_collate``, then apply a batch-level mixing transform.

    A plain class (instead of a lambda/closure) so the collate function can be
    pickled and sent to DataLoader worker processes under the ``spawn`` start
    method (the default on Windows and macOS).
    """

    def __init__(self, batch_transform):
        # Callable invoked as batch_transform(*collated_batch), e.g. a RandomChoice
        # over the enabled mixup/cutmix transforms.
        self.batch_transform = batch_transform

    def __call__(self, batch):
        return self.batch_transform(*default_collate(batch))


def get_sequential_data_loaders(args, num_steps=10):
    """
    Build train and test data loaders for sequential CIFAR-10 or CIFAR-100.

    Args:
        args: Parsed arguments. Reads ``class_num`` (10 or 100), ``data_dir``
            (dataset root; CIFAR is downloaded there if missing), ``b`` (batch
            size), ``j`` (number of DataLoader workers) and the boolean flags
            ``mixup`` and ``cutmix``.
        num_steps (int): Number of steps (frames) in the sequence for each sample,
            forwarded to the sequential dataset wrapper.

    Returns:
        tuple: ``(train_loader, test_loader)`` DataLoader instances. The train
        loader shuffles and drops the last incomplete batch; the test loader
        does neither.

    Raises:
        NotImplementedError: If ``args.class_num`` is neither 10 nor 100.
    """
    # class_num -> (normalization mean, normalization std,
    #               torchvision dataset class, sequential wrapper class)
    dataset_configs = {
        10: (
            (0.4914, 0.4822, 0.4465),
            (0.2023, 0.1994, 0.2010),
            torchvision.datasets.CIFAR10,
            SequentialCIFAR10Wrapper,
        ),
        100: (
            (0.5071, 0.4865, 0.4409),
            (0.2673, 0.2564, 0.2761),
            torchvision.datasets.CIFAR100,
            SequentialCIFAR100Wrapper,
        ),
    }
    if args.class_num not in dataset_configs:
        raise NotImplementedError(f"Dataset class number {args.class_num} not supported.")
    mean, std, dataset_cls, wrapper_cls = dataset_configs[args.class_num]

    # Training uses the augmentation preset; evaluation only converts and normalizes.
    transform_train = ClassificationPresetTrain(
        mean=mean,
        std=std,
        interpolation=InterpolationMode.BILINEAR,
        auto_augment_policy='ta_wide',
        random_erase_prob=0.1,
    )
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    # Base datasets are loaded without transforms; the sequential wrapper is given
    # the transform instead (presumably applied per step — confirm in the wrapper).
    base_train_dataset = dataset_cls(root=args.data_dir, train=True, transform=None, download=True)
    base_test_dataset = dataset_cls(root=args.data_dir, train=False, transform=None, download=True)

    train_dataset = wrapper_cls(base_train_dataset, num_steps=num_steps, transform=transform_train)
    test_dataset = wrapper_cls(base_test_dataset, num_steps=num_steps, transform=transform_test)

    # Batch-level MixUp / CutMix; when both are enabled one is picked per batch.
    mixup_transforms = []
    if args.mixup:
        mixup_transforms.append(RandomMixup(args.class_num, p=1.0, alpha=0.2))
    if args.cutmix:
        mixup_transforms.append(RandomCutmix(args.class_num, p=1.0, alpha=1.0))

    if mixup_transforms:
        # NOTE(review): these transforms receive *sequential* batches from the
        # wrapper. If RandomMixup/RandomCutmix follow torchvision's reference
        # implementation they expect a 4-D (N, C, H, W) batch — confirm they
        # accept the wrapper's output shape.
        collate_fn = _MixupCutmixCollate(torchvision.transforms.RandomChoice(mixup_transforms))
    else:
        collate_fn = default_collate

    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=args.b,
        collate_fn=collate_fn,
        shuffle=True,
        drop_last=True,
        num_workers=args.j,
        pin_memory=True,
    )

    test_loader = DataLoader(
        dataset=test_dataset,
        batch_size=args.b,
        shuffle=False,
        drop_last=False,
        num_workers=args.j,
        pin_memory=True,
    )

    return train_loader, test_loader

